import bert_score
from rouge_chinese import Rouge
import jieba


class TextComparator:
    """Base comparator: two texts are considered a match only when equal.

    Subclasses replace ``compare`` with a metric of their own, and the
    returned value is then whatever that metric produces.
    """

    def compare(self, text1, text2):
        """Return the result of ``text1 == text2`` (exact equality)."""
        are_identical = text1 == text2
        return are_identical


class BertScoreComparator(TextComparator):
    """Compare two texts with BERTScore.

    ``compare`` scores a single candidate/reference pair and returns the
    ``(P, R, F1)`` tuple produced by BERTScore.
    """

    def __init__(self, lang="zh", verbose=True):
        """Configure the scorer.

        Args:
            lang: Language code passed to BERTScore; selects its default
                model for that language. Defaults to ``"zh"`` (the value
                previously hard-coded).
            verbose: Whether BERTScore prints progress while scoring.
                Defaults to ``True`` (the value previously hard-coded).
        """
        self.lang = lang
        self.verbose = verbose
        # Created lazily on first use so construction stays cheap.
        self._scorer = None

    def compare(self, text1, text2):
        """Score ``text1`` (candidate) against ``text2`` (reference).

        Returns:
            The ``(P, R, F1)`` tuple returned by BERTScore for the pair.
        """
        # bert_score.score() loads the tokenizer and model on every call.
        # A BERTScorer loads them once, and we reuse it for later comparisons.
        if self._scorer is None:
            self._scorer = bert_score.BERTScorer(lang=self.lang)
        # Return value: P, R, F1
        return self._scorer.score([text1], [text2], verbose=self.verbose)


class RougeComparator(TextComparator):
    """Compare two texts by ROUGE-L after jieba word segmentation."""

    def compare(self, text1, text2):
        """Return the ROUGE-L entry for ``text1`` (hypothesis) vs ``text2``.

        Both texts are segmented with ``jieba.cut`` and re-joined with single
        spaces. The whitespace-separated tokens are then handed to
        ``Rouge().get_scores``. The value returned is the ``'rouge-l'`` item
        of the first score record.
        """
        # Segment text1 first, then text2, same order as before.
        hypothesis, reference = (' '.join(jieba.cut(t)) for t in (text1, text2))
        all_scores = Rouge().get_scores(hypothesis, reference)
        return all_scores[0]['rouge-l']